# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Benchmark for Keras text vectorization preprocessing layer's adapt method."""

import collections
import itertools
import random
import string
import time

import numpy as np
import tensorflow.compat.v2 as tf

import keras
from keras.layers.preprocessing import index_lookup

# Enable TensorFlow 2.x behavior at import time, before any benchmark runs.
tf.compat.v1.enable_v2_behavior()


# word_gen produces random two-character strings drawn from the ASCII letters
# (lowercase and uppercase), so there are 52 * 52 = 2,704 distinct values.
def word_gen():
    """Yield an endless stream of random two-letter ASCII strings."""
    alphabet = string.ascii_letters
    while True:
        first = random.choice(alphabet)
        second = random.choice(alphabet)
        yield first + second


def get_top_k(dataset, k):
    """Build a frequency-ranked vocabulary in pure Python with a defaultdict.

    Args:
        dataset: Iterable of batches; each batch must expose a `numpy()`
            method returning an iterable of hashable tokens.
        k: Maximum number of tokens to keep.

    Returns:
        A list of tokens ordered from most to least frequent, cut down to the
        first `k` entries when more than `k` distinct tokens were seen. Tokens
        with equal counts keep the order in which they were first encountered.
    """
    frequencies = collections.defaultdict(int)
    for batch in dataset:
        for token in batch.numpy():
            frequencies[token] += 1

    # list.sort is stable, so ties stay in first-seen (insertion) order.
    ranked = list(frequencies.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    vocabulary = [token for token, _ in ranked]

    return vocabulary[:k] if len(vocabulary) > k else vocabulary


class BenchmarkAdapt(tf.test.Benchmark):
    """Benchmark `IndexLookup.adapt` against a pure-Python baseline.

    Every measurement is the mean duration of `_NUM_REPEATS` runs. Durations
    are taken with `time.perf_counter()`, a monotonic high-resolution clock,
    so system clock adjustments cannot distort the reported numbers.
    """

    # Number of timed repetitions averaged into every measurement.
    _NUM_REPEATS = 5

    def _build_dataset_and_layer(self, num_elements, batch_size, k):
        """Create the batched word dataset and an `IndexLookup` layer.

        Args:
            num_elements: Number of words taken from `word_gen`.
            batch_size: Number of words per dataset batch.
            k: Value passed to the layer as `max_tokens`.

        Returns:
            A `(batched_ds, layer)` tuple.
        """
        ds = tf.data.Dataset.from_generator(
            word_gen, tf.string, tf.TensorShape([])
        )
        batched_ds = ds.take(num_elements).batch(batch_size)
        input_t = keras.Input(shape=(), dtype=tf.string)
        layer = index_lookup.IndexLookup(
            max_tokens=k,
            num_oov_indices=0,
            mask_token=None,
            oov_token="OOV",
            dtype=tf.string,
        )
        # Call the layer once on a symbolic input before any timing starts.
        _ = layer(input_t)
        return batched_ds, layer

    @staticmethod
    def _mean_seconds(fn, num_repeats):
        """Return the mean duration in seconds of `num_repeats` calls to `fn`."""
        durations = []
        for _ in range(num_repeats):
            start = time.perf_counter()
            fn()
            durations.append(time.perf_counter() - start)
        return np.mean(durations)

    def run_numpy_implementation(self, num_elements, batch_size, k):
        """Time the pure-Python baseline.

        Each timed run covers `get_top_k` over the dataset followed by
        `layer.set_vocabulary` with the result.

        Args:
            num_elements: Number of words taken from `word_gen`.
            batch_size: Number of words per dataset batch.
            k: Maximum vocabulary size.

        Returns:
            The mean duration of one run, in seconds.
        """
        batched_ds, layer = self._build_dataset_and_layer(
            num_elements, batch_size, k
        )

        def build_and_set_vocabulary():
            layer.set_vocabulary(get_top_k(batched_ds, k))

        return self._mean_seconds(build_and_set_vocabulary, self._NUM_REPEATS)

    def bm_adapt_implementation(self, num_elements, batch_size, k):
        """Time `layer.adapt` and report it alongside the Python baseline.

        Args:
            num_elements: Number of words taken from `word_gen`.
            batch_size: Number of words per dataset batch.
            k: Maximum vocabulary size.
        """
        batched_ds, layer = self._build_dataset_and_layer(
            num_elements, batch_size, k
        )
        avg_time = self._mean_seconds(
            lambda: layer.adapt(batched_ds), self._NUM_REPEATS
        )
        name = "index_lookup_adapt|%s_elements|vocab_size_%s|batch_%s" % (
            num_elements,
            k,
            batch_size,
        )
        baseline = self.run_numpy_implementation(num_elements, batch_size, k)
        # Positive deltas mean adapt() was faster than the baseline.
        delta = baseline - avg_time
        extras = {
            "numpy implementation baseline": baseline,
            "delta seconds": delta,
            "delta percent": (delta / baseline) * 100,
        }
        self.report_benchmark(
            iters=self._NUM_REPEATS,
            wall_time=avg_time,
            extras=extras,
            name=name,
        )

    def benchmark_vocab_size_by_batch(self):
        """Run the adapt benchmark over a grid of element counts and batches."""
        for num_elements in [100, 1000, 10000, 100000, 1000000]:
            for batch in [1, 16, 2048]:
                # The vocabulary cap is one tenth of the element count.
                self.bm_adapt_implementation(
                    num_elements, batch, num_elements // 10
                )


if __name__ == "__main__":
    tf.test.main()
